# -*- coding: utf-8 -*-
"""
Created on Sun Apr  7 15:07:32 2019
版本1.1:改进速度场比例问题
版本1.2：改进可视化图显示区域问题
版本1.3：改进空间分布边缘非圆问题,速度场比例问题
版本1.4：改进斜高斯问题
"""
#代码段说明
#输入：无
#输出:带标签四维度模拟数据、数据可视化图
#功能：实现模拟数据生成与可视化
#注意 votical代替velocity
import numpy as np
import matplotlib.pyplot as plt
import random
from scipy.stats import gaussian_kde
import os
from astropy.io import fits
from pandas import  DataFrame

#设置模拟数据超参数


def generate_cluster(cluster_stars_num, loc_list, scale_list, bias = 0):
    """Draw mock cluster-member stars (label 0).

    Positions are normal around ``loc_list[0:2]`` with the base scales
    multiplied by a random factor in [0.5, 1.5].  Proper motions are normal
    around ``loc_list[2:4]`` shifted by ``uniform(-12, 12) * bias``, with the
    base scales multiplied by a random factor in [0.05, 0.15].

    Args:
        cluster_stars_num: number of stars to draw.
        loc_list: sequence whose first four entries are the centres of
            (raj, dej, pmra, pmde).
        scale_list: sequence whose first four entries are the base standard
            deviations of (raj, dej, pmra, pmde).
        bias: multiplier applied to the random proper-motion centre offset.

    Returns:
        Array of shape (cluster_stars_num, 5) with columns
        raj, dej, pmra, pmde, label (label is always 0).
    """
    shape = [cluster_stars_num]

    # Positions: one random scale factor per axis, drawn right before use.
    ra_sigma = scale_list[0] * random.uniform(0.5, 1.5)
    raj = np.random.normal(loc=loc_list[0], scale=ra_sigma, size=shape)
    de_sigma = scale_list[1] * random.uniform(0.5, 1.5)
    dej = np.random.normal(loc=loc_list[1], scale=de_sigma, size=shape)

    # Proper motions: the centre offset is drawn before the scale factor.
    pmra_centre = loc_list[2] + random.uniform(-12, 12) * bias
    pmra_sigma = scale_list[2] * random.uniform(0.05, 0.15)
    pmra = np.random.normal(loc=pmra_centre, scale=pmra_sigma, size=shape)
    pmde_centre = loc_list[3] + random.uniform(-12, 12) * bias
    pmde_sigma = scale_list[3] * random.uniform(0.05, 0.15)
    pmde = np.random.normal(loc=pmde_centre, scale=pmde_sigma, size=shape)

    label = np.zeros(shape)
    return np.column_stack((raj, dej, pmra, pmde, label))

    
def generate_nocluster(nocluster_stars_num, loc_list, scale_list):
    """Draw mock field (non-cluster) stars (label 1).

    Positions are ``sqrt(r)`` along a random direction with ``r`` uniform in
    [0, 900), i.e. points inside a disc of radius 30 centred on the origin.
    Proper motions are normal around ``loc_list[2:4]`` with the base scales
    multiplied by a random factor in [0.5, 4], then rotated by one random
    angle in [0, 3.14).

    Args:
        nocluster_stars_num: number of stars to draw.
        loc_list: sequence; only entries 2 and 3 (proper-motion centres)
            are read.
        scale_list: sequence; only entries 2 and 3 (proper-motion base
            standard deviations) are read.

    Returns:
        Array of shape (nocluster_stars_num, 5) with columns
        raj, dej, pmra, pmde, label (label is always 1).
    """
    count = [nocluster_stars_num]
    radius = 900

    # Spatial part: random direction, square-rooted uniform radial draw.
    angle = np.random.uniform(0, 2 * np.pi, size=count)
    dist = np.random.uniform(0, radius, size=count) ** 0.5
    raj = np.sin(angle) * dist
    dej = np.cos(angle) * dist

    # Kinematic part: axis-aligned Gaussians with randomly widened scales.
    sigma_ra = scale_list[2] * random.uniform(0.5, 4)
    raw_pmra = np.random.normal(loc=loc_list[2], scale=sigma_ra, size=count)
    sigma_de = scale_list[3] * random.uniform(0.5, 4)
    raw_pmde = np.random.normal(loc=loc_list[3], scale=sigma_de, size=count)

    # Rotate the proper-motion cloud by one random angle.
    rot = np.random.uniform(0, 3.14)
    cos_rot, sin_rot = np.cos(rot), np.sin(rot)
    pmra_rot = raw_pmra * cos_rot - raw_pmde * sin_rot
    pmde_rot = raw_pmra * sin_rot + raw_pmde * cos_rot

    label = np.zeros(count) + 1
    return np.column_stack((raj, dej, pmra_rot, pmde_rot, label))

def data_predeal(cluster, nocluster):
    """Concatenate cluster and field stars and shuffle the rows.

    Args:
        cluster: array of cluster-member rows.
        nocluster: array of field-star rows.

    Returns:
        A new array holding all rows of both inputs in random order
        (shuffled along the first axis with ``np.random.permutation``).
    """
    stacked = np.concatenate((cluster, nocluster), axis=0)
    shuffled = np.random.permutation(stacked)
    return shuffled

def plot_location(mix_stars):
    """Scatter-plot column 0 against column 1 on the current axes.

    Points are coloured by column 4 (the label column) and drawn with
    marker size 10.
    """
    x_pos, y_pos, tags = mix_stars[:, 0], mix_stars[:, 1], mix_stars[:, 4]
    plt.scatter(x_pos, y_pos, c=tags, s=10)
    
def plot_votical(mix_stars):
    """Scatter-plot column 2 against column 3 on the current axes.

    These are the proper-motion ("votical" = velocity) columns.  Points are
    coloured by column 4 (the label column) and drawn with marker size 10.
    """
    pm_x, pm_y, tags = mix_stars[:, 2], mix_stars[:, 3], mix_stars[:, 4]
    plt.scatter(pm_x, pm_y, c=tags, s=10)

def gauss_kde_location(mix_stars, loc_value):
    """Gaussian kernel density estimate of the position columns on a grid.

    Args:
        mix_stars: 2-D array; columns 0 and 1 are used as the sample.
        loc_value: ``[x_min, x_max, y_min, y_max]`` bounds of the grid.

    Returns:
        Tuple ``(z, xgrid)``: ``z`` is the flat array of 3600 density values
        on a 60x60 grid, ``xgrid`` is the (60, 60) x-coordinate mesh, which
        callers use for its shape when reshaping ``z``.
    """
    x_min, x_max, y_min, y_max = (loc_value[k] for k in range(4))
    sample = np.vstack((mix_stars[:, 0], mix_stars[:, 1]))
    estimator = gaussian_kde(sample)

    xgrid, ygrid = np.meshgrid(
        np.linspace(x_min, x_max, 60),
        np.linspace(y_min, y_max, 60),
    )
    query = np.array([xgrid.ravel(), ygrid.ravel()])
    z = estimator(query)
    return z, xgrid

    
def gauss_kde_votical(mix_stars, pm_value):
    """Gaussian kernel density estimate of the proper-motion columns on a grid.

    Args:
        mix_stars: 2-D array; columns 2 and 3 are used as the sample.
        pm_value: ``[x_min, x_max, y_min, y_max]`` bounds of the grid.

    Returns:
        Tuple ``(z, xgrid)``: ``z`` is the flat array of 3600 density values
        on a 60x60 grid, ``xgrid`` is the (60, 60) x-coordinate mesh, which
        callers use for its shape when reshaping ``z``.
    """
    x_min, x_max, y_min, y_max = (pm_value[k] for k in range(4))
    sample = np.vstack((mix_stars[:, 2], mix_stars[:, 3]))
    estimator = gaussian_kde(sample)

    xgrid, ygrid = np.meshgrid(
        np.linspace(x_min, x_max, 60),
        np.linspace(y_min, y_max, 60),
    )
    query = np.array([xgrid.ravel(), ygrid.ravel()])
    z = estimator(query)
    return z, xgrid

def gauss_kde_location_plt(z, xgrid, loc_value):
    """Draw the position-space density as an image on the current axes.

    Args:
        z: flat density values; reshaped to ``xgrid.shape``.
        xgrid: mesh array, used only for its shape.
        loc_value: ``[x_min, x_max, y_min, y_max]`` passed as the image
            extent.
    """
    density = z.reshape(xgrid.shape)
    # No colorbar is drawn (it was previously a 'density'-labelled colorbar).
    plt.imshow(density, origin='lower', aspect='auto', extent=loc_value, cmap='Blues')
    
def gauss_kde_votical_plt(z, xgrid, pm_value):
    """Draw the proper-motion density as an image on the current axes.

    Args:
        z: flat density values; reshaped to ``xgrid.shape``.
        xgrid: mesh array, used only for its shape.
        pm_value: ``[x_min, x_max, y_min, y_max]`` passed as the image
            extent.
    """
    density = z.reshape(xgrid.shape)
    # No colorbar is drawn (it was previously a 'density'-labelled colorbar).
    plt.imshow(density, origin='lower', aspect='auto', extent=pm_value, cmap='Reds')
    
def cluster_subplot(mix_stars, loc_value, pm_value):
    """Show scatter plots and density maps together in a 2x2 figure.

    Panels: (1) position scatter, (2) position density,
    (3) proper-motion scatter, (4) proper-motion density.

    Each kernel density estimate is computed once and unpacked, instead of
    being evaluated twice just to pick the two elements of its result.

    Args:
        mix_stars: 2-D array of stars (columns 0-3 are plotted, column 4
            colours the scatter plots).
        loc_value: ``[x_min, x_max, y_min, y_max]`` for the position density.
        pm_value: ``[x_min, x_max, y_min, y_max]`` for the proper-motion
            density.
    """
    plt.figure(figsize=(10, 10))

    plt.subplot(2, 2, 1)
    plot_location(mix_stars)

    plt.subplot(2, 2, 2)
    loc_density, loc_grid = gauss_kde_location(mix_stars, loc_value)
    gauss_kde_location_plt(loc_density, loc_grid, loc_value)

    plt.subplot(2, 2, 3)
    plot_votical(mix_stars)

    plt.subplot(2, 2, 4)
    pm_density, pm_grid = gauss_kde_votical(mix_stars, pm_value)
    gauss_kde_votical_plt(pm_density, pm_grid, pm_value)
    
def cluster_loc_plotsave(save_path, mix_stars, loc_value):
    """Render the position density map without decorations and save it.

    The current figure is resized to 7/3 x 7/3 inches, axes, ticks and
    margins are removed, and the density image is drawn and written to
    ``save_path``.  The kernel density estimate is computed once (it used to
    be evaluated twice to pick the two elements of its result).

    Args:
        save_path: output image path passed to ``plt.savefig``.
        mix_stars: 2-D array of stars; columns 0 and 1 are used.
        loc_value: ``[x_min, x_max, y_min, y_max]`` grid bounds / extent.
    """
    plt.axis('off')
    fig = plt.gcf()
    # 7/3 inch corresponds to 700x700 pixels only at dpi = 300.
    # NOTE(review): savefig below does not pass dpi, so the pixel size
    # depends on the active matplotlib settings -- confirm the intended dpi.
    fig.set_size_inches(7.0/3, 7.0/3)
    axes = plt.gca()
    axes.xaxis.set_major_locator(plt.NullLocator())
    axes.yaxis.set_major_locator(plt.NullLocator())
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    density, grid = gauss_kde_location(mix_stars, loc_value)
    gauss_kde_location_plt(density, grid, loc_value)
    plt.savefig(save_path)

def cluster_vel_plotsave(save_path, mix_stars, pm_value):
    """Render the proper-motion density map without decorations and save it.

    The current figure is resized to 7/3 x 7/3 inches, axes, ticks and
    margins are removed, and the density image is drawn and written to
    ``save_path``.  The kernel density estimate is computed once (it used to
    be evaluated twice to pick the two elements of its result).

    Args:
        save_path: output image path passed to ``plt.savefig``.
        mix_stars: 2-D array of stars; columns 2 and 3 are used.
        pm_value: ``[x_min, x_max, y_min, y_max]`` grid bounds / extent.
    """
    plt.axis('off')
    fig = plt.gcf()
    # 7/3 inch corresponds to 700x700 pixels only at dpi = 300.
    # NOTE(review): savefig below does not pass dpi, so the pixel size
    # depends on the active matplotlib settings -- confirm the intended dpi.
    fig.set_size_inches(7.0/3, 7.0/3)
    axes = plt.gca()
    axes.xaxis.set_major_locator(plt.NullLocator())
    axes.yaxis.set_major_locator(plt.NullLocator())
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    density, grid = gauss_kde_votical(mix_stars, pm_value)
    gauss_kde_votical_plt(density, grid, pm_value)
    plt.savefig(save_path)

def mockdata_save(i, mix_stars, save_dir='/content/drive/Colab Notebooks/dataset/new_veldata/txt/'):
    """Write one mock data set to ``<save_dir>/<i>.txt``.

    Args:
        i: sample index; used as the file name stem.
        mix_stars: 2-D array to write, one row per line, formatted with
            ``'%8e'``.
        save_dir: directory receiving the file.  Defaults to the directory
            that used to be hard-coded, so existing two-argument calls write
            to the same place as before.
    """
    target = os.path.join(save_dir, str(i) + '.txt')
    np.savetxt(target, mix_stars, fmt='%8e')
    
def generate_main(num, save_path, control=1, bias=0):
    """Generate ``num`` mock samples and save their density images.

    For every sample a position density image is written to
    ``<save_path><name>/loc_fig/<i>.jpg`` and a proper-motion density image
    to ``<save_path><name>/pm_fig/<i>.jpg``, where ``name`` is ``'cluster'``
    when ``control == 1`` and ``'no_cluster'`` otherwise.

    Changes relative to the earlier implementation:
      * Both density maps used to be rendered onto the current axes before
        the save helpers rendered them again; that redundant pass (four
        extra KDE evaluations per sample) is removed.
      * The figure is closed after every saved image, so images no longer
        pile up on one axes across iterations.
      * Paths are built with ``os.path.join`` instead of hard-coded
        backslashes, so the output layout is also correct on POSIX systems.

    Args:
        num: number of samples to generate.
        save_path: path prefix; it is concatenated directly with the folder
            name, so it should end with a path separator.
        control: multiplier for the number of cluster stars (0 yields
            samples without cluster members); also selects the folder name.
        bias: forwarded to ``generate_cluster`` as the proper-motion centre
            offset multiplier.
    """
    name = 'cluster' if control == 1 else 'no_cluster'
    loc_dir = os.path.join(save_path + name, 'loc_fig')
    pm_dir = os.path.join(save_path + name, 'pm_fig')
    mkdir(loc_dir)
    mkdir(pm_dir)

    for i in range(num):
        cluster_stars_num = int(random.uniform(150, 200) * control)
        nocluster_stars_num = int(random.uniform(2000, 3000))
        loc_list = (np.random.random([8]) - 0.5) * 5
        scale_list = (np.random.random([8]) + 1) * 5
        cluster = generate_cluster(cluster_stars_num, loc_list, scale_list, bias)
        nocluster = generate_nocluster(nocluster_stars_num, loc_list, scale_list)
        mix_stars = data_predeal(cluster, nocluster)
        loc_value, pm_value = get_data_config(mix_stars, rate=3)

        file_name = str(i) + '.jpg'
        cluster_loc_plotsave(os.path.join(loc_dir, file_name), mix_stars, loc_value)
        plt.close()  # start the next image on a fresh figure
        cluster_vel_plotsave(os.path.join(pm_dir, file_name), mix_stars, pm_value)
        plt.close()  # release the figure so nothing accumulates over the loop
        

def plot(label = 1, bias = 0):
    """Generate one mock sample and display its 2x2 overview figure.

    Args:
        label: multiplier for the cluster star count (150 * label), so 0
            produces a sample without cluster members.
        bias: forwarded to ``generate_cluster`` as the proper-motion centre
            offset multiplier.
    """
    member_count = int(random.uniform(150, 151)) * label
    field_count = int(random.uniform(2999, 3000))

    centres = (np.random.random([8]) - 0.5) * 5
    spreads = (np.random.random([8]) + 1) * 5

    members = generate_cluster(member_count, centres, spreads, bias=bias)
    field = generate_nocluster(field_count, centres, spreads)
    stars = data_predeal(members, field)

    loc_value, pm_value = get_data_config(stars)
    cluster_subplot(stars, loc_value, pm_value)
    plt.show()

def plot_example(num, label=1, bias=0):
    """Display ``num`` freshly generated example figures.

    Args:
        num: how many examples to show.
        label: 1 for a sample containing a cluster, 0 for one without.
        bias: how far the cluster's velocity centre deviates from the field
            stars' velocity centre in proper-motion space.
    """
    remaining = num
    while remaining > 0:
        plot(label=label, bias=bias)
        remaining -= 1

def get_data_config(data_deal, rate = 3):
    """Derive plotting ranges for the position and proper-motion panels.

    The result is meant for plotting only and must not be used when saving
    data.

    Args:
        data_deal: 2-D array; columns 0-1 are positions, columns 2-3 are
            proper motions.
        rate: half-width of the proper-motion window, in units of the mean
            of the two sample standard deviations (ddof=1).

    Returns:
        Tuple ``(loc_value, pm_value)``.  ``loc_value`` is
        ``[x_min, x_max, y_min, y_max]`` of the positions; ``pm_value`` is
        the same layout, centred on the proper-motion means and extending
        ``rate`` times the averaged standard deviation on each side.
    """
    ra_col, de_col = data_deal[:, 0], data_deal[:, 1]
    pmra_col, pmde_col = data_deal[:, 2], data_deal[:, 3]

    # Position window: the full extent of the data.
    ra_hi, ra_lo = ra_col.max(), ra_col.min()
    de_hi, de_lo = de_col.max(), de_col.min()
    loc_value = [ra_lo, ra_hi, de_lo, de_hi]

    # Proper-motion window: same half-width on both axes so the panel is
    # square in data units.
    centres = (pmra_col.mean(), pmde_col.mean())
    averaged_std = (pmra_col.std(ddof=1) + pmde_col.std(ddof=1)) / 2
    half_width = rate * averaged_std

    pm_value = []
    for centre in centres:
        pm_value.extend((centre - half_width, centre + half_width))
    return loc_value, pm_value

def pointcycle(point_num,radius):
    """Plot ``point_num`` random points inside a disc as black dots.

    The radial coordinate is the square root of a uniform draw from
    [0, radius), so the disc has radius ``sqrt(radius)``.
    """
    count = [point_num]
    angle = np.random.uniform(0, 2 * np.pi, size=count)
    dist = np.random.uniform(0, radius, size=count) ** 0.5
    xs = np.sin(angle) * dist
    ys = np.cos(angle) * dist
    plt.plot(xs, ys, '.', color="black")
'''
以下为高维度数据处理与可视化函数
'''
#*********************************************************************************
def CMD_plot(data):
    """Show a scatter plot of the rows whose column 4 equals 0.

    Column 7 is drawn on the x axis and the negated column 6 on the y axis
    (presumably BP-RP colour and G magnitude, following the column layout
    built by ``get_fits_to_array`` -- confirm against the caller).
    """
    is_member = data[:, 4] == 0
    members = data[is_member]
    colour, neg_mag = members[:, 7], -members[:, 6]
    plt.scatter(colour, neg_mag, s=10, c='blue')
    plt.show()
    
def CMD_plot_save(data, save_path):
    """Scatter-plot the rows whose column 4 is not -1, save and show it.

    Column 7 is drawn on the x axis and the negated column 6 on the y axis;
    points are coloured by their column-4 value.  (-1 presumably marks
    unassigned stars -- confirm against the caller.)

    Args:
        data: 2-D array with at least 8 columns.
        save_path: output image path passed to ``plt.savefig``.
    """
    keep = data[:, 4] != -1
    selected = data[keep]
    plt.scatter(selected[:, 7], -selected[:, 6], s=10, c=selected[:, 4])
    plt.savefig(save_path)
    plt.show()
    
def CMD_subplot_save(data, save_path):
    """Save and show a 1x3 figure highlighting rows with column 4 != -1.

    Panels: (1) column 0 vs column 1, (2) column 2 vs column 3,
    (3) column 7 vs negated column 6.  In every panel all rows are drawn in
    black first and the selected rows are drawn in blue on top.

    Args:
        data: 2-D array with at least 8 columns.
        save_path: output image path passed to ``plt.savefig``.
    """
    def _xy(rows, x_col, y_col, flip_y):
        # Pick the two plotted columns, negating y when requested.
        y_vals = -rows[:, y_col] if flip_y else rows[:, y_col]
        return rows[:, x_col], y_vals

    highlighted = data[data[:, 4] != -1]
    # (x column, y column, negate y) for each of the three panels.
    panel_specs = ((0, 1, False), (2, 3, False), (7, 6, True))

    plt.figure(figsize=(18, 6))
    for position, (x_col, y_col, flip_y) in enumerate(panel_specs, start=1):
        plt.subplot(1, 3, position)
        all_x, all_y = _xy(data, x_col, y_col, flip_y)
        plt.scatter(all_x, all_y, s=10, c='Black')
        sel_x, sel_y = _xy(highlighted, x_col, y_col, flip_y)
        plt.scatter(sel_x, sel_y, s=10, c='Blue')
    plt.savefig(save_path)
    plt.show()
    
    
    
def mkdir(path):
    """Create ``path`` (including missing parents) and report what happened.

    The directory is created directly and an already-existing path is
    detected from the resulting ``FileExistsError``, instead of checking
    for existence first; this removes the window in which another process
    could create the directory between the check and the creation.

    Args:
        path: directory path to create.

    Prints:
        Two lines when the directory was created, one line when the path
        already existed.
    """
    try:
        os.makedirs(path)  # also creates intermediate directories
    except FileExistsError:
        print("---  There is this folder(it had already been build)!  ---")
    else:
        print("---  new folder...  ---")
        print("---  OK  ---")
        
#********************************读fits*****************************#       
def get_fits_to_array(file_path):
    """Read selected columns of a FITS table into a float array.

    The table in HDU 1 is copied into memory inside a ``with`` block so the
    file handle is closed again (it used to be left open).  Columns are
    copied one whole column at a time rather than element by element.

    Output column layout (source field position in parentheses):
        0 (0), 1 (1)  -- first two table fields (presumably RA/Dec --
                         confirm against the catalogue)
        2 (10)        -- pmRA
        3 (12)        -- pmDE
        4             -- not filled, stays 0
        5 (8)         -- plx
        6 (17)        -- presumably the G magnitude (the earlier comment
                         read "Gamg") -- confirm
        7 (27)        -- BP-RP

    Args:
        file_path: path of the FITS file.

    Returns:
        Array with 8 columns containing only the rows without NaN.
    """
    # destination column -> position of the source field in the table
    field_for_column = {0: 0, 1: 1, 2: 10, 3: 12, 5: 8, 6: 17, 7: 27}

    with fits.open(file_path) as hdu_list:
        # np.array copies the records, so they stay valid after closing.
        table = np.array(hdu_list[1].data)

    field_names = table.dtype.names
    data_arr = np.zeros([table.size, 8])
    for column, field_pos in field_for_column.items():
        data_arr[:, column] = table[field_names[field_pos]]

    # Drop every row that contains a NaN.
    complete_rows = DataFrame(data_arr).dropna()
    return np.array(complete_rows)